# -*- coding: utf-8 -*-
"""Untitled27.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1B9SZoZrU-cBPcNbOTgAVuCVQKlpso2-R
"""

# -*- coding: utf-8 -*-
"""
Core library for matrix product approximation experiments.
Contains functions for matrix generation, algorithms, bounds,
and specific logic for Experiment 3: Scalability Analysis.
"""

# --- Imports ---
import numpy as np
import scipy
import scipy.linalg
import cvxpy as cp
import matplotlib.pyplot as plt
import pandas as pd
from tqdm.notebook import tqdm # Using tqdm.notebook for better notebook integration
import warnings
from typing import Dict, Any, List, Optional, Tuple
import itertools
import math
import traceback
import time
import os
import json # For saving results if needed

# --- Output Directories and Warning Filters ---
# Create the 'plots' and 'results' directories (relative to the current
# working directory) as an import-time side effect; existing ones are kept.
os.makedirs("plots", exist_ok=True)
os.makedirs("results", exist_ok=True)

# Suppress common warnings for cleaner output.
# NOTE: these filters are installed process-wide when this module is imported,
# so matching warnings raised by any other code in the same process are hidden
# as well.
warnings.filterwarnings("ignore", category=RuntimeWarning, module="cvxpy")
warnings.filterwarnings("ignore", category=UserWarning, message="Solution may be inaccurate.*")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in divide")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="invalid value encountered in scalar divide")
warnings.filterwarnings("ignore", category=RuntimeWarning, message="Mean of empty slice")
warnings.filterwarnings("ignore", category=FutureWarning, message="Using `tqdm.autonotebook.tqdm` in notebook mode.*")


# --- USER-PROVIDED Publication Quality Plotting Settings (rcParams) ---
# Applied once at import time. `plt.rcParams` is matplotlib's global default
# configuration, so every figure created after this module is imported picks
# up these values unless they are overridden explicitly.
plt.rcParams.update({
    'font.size': 20,               # Base font size
    'axes.titlesize': 24,          # Font size for subplot titles
    'axes.labelsize': 22,          # Font size for x and y labels
    'xtick.labelsize': 20,         # Font size for x-axis tick labels
    'ytick.labelsize': 20,         # Font size for y-axis tick labels
    'legend.fontsize': 20,         # Font size for legends
    'figure.titlesize': 26,        # Font size for the figure's suptitle
    'figure.figsize': (18, 8),     # Default figure size
    'figure.dpi': 150,             # Higher DPI for better quality
    'savefig.dpi': 300,            # Even higher DPI for saved figures
    'lines.linewidth': 3,          # Default line width
    'lines.markersize': 12,        # Default marker size
    'axes.linewidth': 1.5,         # Width of the axes lines
    'grid.linewidth': 1.0,         # Width of the grid lines
    'axes.grid': True,             # Show grid by default
    'grid.alpha': 0.3,             # Grid transparency
    'axes.titleweight': 'bold',    # Bold subplot titles
    'axes.labelweight': 'bold',    # Bold axis labels
    'figure.titleweight': 'bold',  # Bold figure title
    'mathtext.default': 'regular', # Math text style
    'mathtext.fontset': 'cm',      # Computer Modern math font
    'figure.facecolor': 'white',   # Ensure figure background is white
    'axes.facecolor': 'white',     # Ensure axes background is white
    'savefig.facecolor': 'white'   # Ensure saved figure background is white
})

# --- USER-PROVIDED High-quality plot styles (IMPROVED_STYLES_BASE) ---
# Maps a series name to a dict of line-style options (color, marker,
# linestyle, legend label, line width, marker size, z-order, ...).
# The 'zorder' values decrease from 10 to 0 in declaration order, so earlier
# entries are meant to be drawn on top of later ones.
# NOTE(review): the option names look like matplotlib `plot` keyword
# arguments; the code that consumes them is not visible in this part of the
# file -- confirm before renaming any option.
IMPROVED_STYLES_BASE = {
    'Optimal Error v_k^*': {
        'color': 'gold', 'marker': '*', 'linestyle': '-', 'label': r'Optimal $v_k^*$',
        'lw': 4.0, 'markersize': 16, 'zorder': 10, 'markeredgewidth': 1.5, 'markeredgecolor': 'black'
    },
    'Your Bound (QP CVXPY Best)': {
        'color': 'black', 'marker': 'o', 'linestyle': '-', 'label': 'Bound (QP Best)',
        'lw': 3.5, 'markersize': 12, 'zorder': 9, 'markeredgewidth': 1.0
    },
    'Your Bound (QP Analytical)': {
        'color': 'dimgrey', 'marker': '^', 'linestyle': ':', 'label': 'Bound (QP Approx)',
        'lw': 3.0, 'markersize': 12, 'zorder': 8, 'markeredgewidth': 1.0
    },
    'Your Bound (Binary)': {
        'color': 'darkgrey', 'marker': 's', 'linestyle': '--', 'label': 'Bound (Binary)',
        'lw': 3.0, 'markersize': 12, 'zorder': 7, 'markeredgewidth': 1.0
    },
    'Bound (Leverage Score Exp.)': {
        'color': 'deepskyblue', 'marker': 'D', 'linestyle': '-.', 'label': 'Bound (Lev. Score Exp.)',
        'lw': 3.0, 'markersize': 12, 'zorder': 6, 'markeredgewidth': 1.0, 'markeredgecolor': 'navy'
    },
    'Bound (Sketching Simple)': {
        'color': 'sandybrown', 'marker': 'P', 'linestyle': ':', 'label': 'Bound (Sketching Simple)',
        'lw': 3.0, 'markersize': 12, 'zorder': 5, 'markeredgewidth': 1.0, 'markeredgecolor': 'saddlebrown'
    },
    'Error Leverage Score (Actual)': {
        'color': 'blue', 'marker': 'x', 'linestyle': '-', 'label': 'Leverage Score Sampling',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 4, 'markeredgewidth': 2.0
    },
    'Error CountSketch (Actual)': {
        'color': 'orange', 'marker': 'd', 'linestyle': '--', 'label': 'CountSketch',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 3, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkorange'
    },
    'Error SRHT (Actual)': {
        'color': 'red', 'marker': 'v', 'linestyle': '-.', 'label': 'SRHT',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 2, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkred'
    },
    'Error Gaussian (Actual)': {
        'color': 'darkviolet', 'marker': '<', 'linestyle': ':', 'label': 'Gaussian Proj.',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 1, 'markeredgewidth': 1.5, 'markeredgecolor': 'indigo'
    },
    'Error Greedy OMP (Actual)': {
        'color': 'forestgreen', 'marker': '>', 'linestyle': '-', 'label': 'Greedy OMP',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 0, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkgreen'
    },
}

# --- ADAPTED Plotting Styles for Experiment (Based on User's IMPROVED_STYLES with key renaming) ---
# Re-keys IMPROVED_STYLES_BASE under the result names stored by
# run_experiment_3_scalability (e.g. 'Bound (QP Aux)', 'Optimal Sampling').
# Entries written as {**base, 'label': ...} are shallow copies whose legend
# label is overridden. The remaining entries are the SAME dict objects as in
# IMPROVED_STYLES_BASE, so mutating one of them in place changes both tables.
ADAPTED_EXPERIMENT_STYLES = {
    'Optimal Error v_k^*': IMPROVED_STYLES_BASE['Optimal Error v_k^*'],
    'Bound (QP Aux)': {**IMPROVED_STYLES_BASE['Your Bound (QP CVXPY Best)'], 'label': 'Bound (QP Aux)'},
    'Bound (Scaled Id)': {**IMPROVED_STYLES_BASE['Your Bound (QP Analytical)'], 'label': 'Bound (Scaled Id)'},
    'Your Bound (Binary)': IMPROVED_STYLES_BASE['Your Bound (Binary)'],
    'Bound (Sampling)': {**IMPROVED_STYLES_BASE['Bound (Leverage Score Exp.)'], 'label': 'Bound (Sampling)'},
    'Bound (Sketching)': {**IMPROVED_STYLES_BASE['Bound (Sketching Simple)'], 'label': 'Bound (Sketching)'},
    'Optimal Sampling': {**IMPROVED_STYLES_BASE['Error Leverage Score (Actual)'], 'label': 'Optimal Sampling'},
    'CountSketch': IMPROVED_STYLES_BASE['Error CountSketch (Actual)'],
    'SRHT': IMPROVED_STYLES_BASE['Error SRHT (Actual)'],
    'Gaussian Proj.': IMPROVED_STYLES_BASE['Error Gaussian (Actual)'],
    'Greedy OMP': IMPROVED_STYLES_BASE['Error Greedy OMP (Actual)'],
}

# --- Basic Helper Functions ---
def frob_norm_sq(M: np.ndarray) -> float:
    """Return the squared Frobenius norm of ``M``.

    The input is converted to a float64 array first, so integer arrays and
    nested lists are accepted as long as they are two-dimensional.
    """
    as_float64 = np.asarray(M, dtype=np.float64)
    frobenius = np.linalg.norm(as_float64, ord='fro')
    return frobenius ** 2

def calculate_rho_g(A: np.ndarray, B: np.ndarray) -> float:
    """Return rho_G = trace(G) / sum(G) for G = (A^T A) * (B^T B) (elementwise).

    Args:
        A: 2-D array; converted to float64.
        B: 2-D array with the same number of columns as ``A``.

    Returns:
        * ``0.0`` when there are no columns or when ``G`` is numerically zero;
        * ``np.inf`` when the sum of ``G`` is at most 1e-12 while ``G`` itself
          is not numerically zero;
        * the ratio ``trace(G) / sum(G)`` otherwise (floored at 0);
        * ``np.nan`` if anything goes wrong (mismatched or non-2-D shapes,
          conversion errors). In that case a ``RuntimeWarning`` is emitted
          instead of raising, so callers can keep running best-effort.
    """
    try:
        A_f64 = np.asarray(A, dtype=np.float64)
        B_f64 = np.asarray(B, dtype=np.float64)
        # Unpacking also rejects inputs that are not exactly 2-D.
        _, n_dim = A_f64.shape
        _, n_dim_B = B_f64.shape
        if n_dim != n_dim_B:
            raise ValueError(f"Dimension mismatch: A_cols={n_dim}, B_cols={n_dim_B}")
        if n_dim == 0:
            return 0.0

        # Hadamard product of the two Gram matrices.
        G = (A_f64.T @ A_f64) * (B_f64.T @ B_f64)
        trace_G = np.trace(G)
        sum_G = np.sum(G)

        if sum_G <= 1e-12:
            # Degenerate denominator: 0 if G vanishes entirely, otherwise the
            # ratio is unbounded.
            if np.linalg.norm(G, 'fro') < 1e-12:
                return 0.0
            return np.inf
        return max(0, trace_G / sum_G)
    except Exception as e:
        warnings.warn(f"Error calculating Rho_G: {e}", RuntimeWarning)
        return np.nan

# --- Matrix Generation ---
def generate_matrices(m: int, p: int, n_dim: int,
                      cancellation_pairs: int = 0,
                      noise_level: float = 0.0,
                      seed: Optional[int] = None,
                      distribution: str = 'gaussian') -> Tuple[np.ndarray, np.ndarray]:
    """Draw random matrices ``A`` (m x n_dim) and ``B`` (p x n_dim).

    Args:
        m: Number of rows of ``A``.
        p: Number of rows of ``B``.
        n_dim: Shared number of columns.
        cancellation_pairs: Number of column pairs to turn into cancelling
            pairs (same column in ``A``, negated column in ``B``, both scaled
            by a common random factor in [0.7, 1.3)). Ignored unless
            ``n_dim >= 2 * cancellation_pairs``.
        noise_level: If positive, Gaussian noise with standard deviation
            ``noise_level * std(matrix)`` is added to each matrix (a unit
            reference is used when the matrix std is not above 1e-9).
        seed: If given, re-seeds NumPy's global RNG before sampling.
        distribution: ``'gaussian'`` or ``'uniform'`` (uniform on [-1, 1)),
            case-insensitive.

    Returns:
        The pair ``(A, B)``.

    Raises:
        ValueError: For an unsupported ``distribution``.
    """
    if seed is not None:
        np.random.seed(seed)

    dist_name = distribution.lower()
    if dist_name == 'gaussian':
        A = np.random.randn(m, n_dim)
        B = np.random.randn(p, n_dim)
    elif dist_name == 'uniform':
        A = np.random.rand(m, n_dim) * 2 - 1
        B = np.random.rand(p, n_dim) * 2 - 1
    else:
        raise ValueError(f"Unsupported distribution: {distribution}. Choose 'gaussian' or 'uniform'.")

    if 0 < cancellation_pairs and 2 * cancellation_pairs <= n_dim:
        chosen = np.random.choice(n_dim, 2 * cancellation_pairs, replace=False)
        # Consecutive entries of `chosen` form one (source, mirror) pair.
        for src, mirror in zip(chosen[0::2], chosen[1::2]):
            A[:, mirror] = A[:, src]
            B[:, mirror] = -B[:, src]
            scale_factor_A = np.random.uniform(0.7, 1.3)
            scale_factor_B = np.random.uniform(0.7, 1.3)
            A[:, [src, mirror]] *= scale_factor_A
            B[:, [src, mirror]] *= scale_factor_B

    if noise_level > 0:
        std_A = np.std(A)
        std_B = np.std(B)
        ref_A = std_A if std_A > 1e-9 else 1.0
        ref_B = std_B if std_B > 1e-9 else 1.0
        A += np.random.normal(0, noise_level * ref_A, size=(m, n_dim))
        B += np.random.normal(0, noise_level * ref_B, size=(p, n_dim))
    return A, B

# --- USER'S BOUND FUNCTION ---
def compute_theoretical_bounds(A_in: np.ndarray, B_in: np.ndarray, k: int, n_dim: int,
                               frob_ABt_exact_sq_for_ratio: float) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Compute the three 'Your Bound' ratios for approximating ``A @ B.T`` with budget ``k``.

    Args:
        A_in: Left matrix; converted to float64.
        B_in: Right matrix; converted to float64. Used as ``A @ B.T``, so it
            must have the same number of columns as ``A_in``.
        k: Budget the bounds are evaluated for.
        n_dim: Number of columns, supplied by the caller (it is not checked
            against the array shapes here).
        frob_ABt_exact_sq_for_ratio: Squared Frobenius norm of the exact
            product, used as the denominator of the returned ratios.

    Returns:
        ``(ratios_sq, timings)``:

        * ``ratios_sq`` has the keys ``'Your Bound (Binary)'``,
          ``'Your Bound (QP Analytical)'`` and ``'Your Bound (QP CVXPY Best)'``.
          The Binary and CVXPY values are bound values divided by
          ``frob_ABt_exact_sq_for_ratio``; the Analytical value is computed
          directly as a ratio (``1 - k * gamma / n_dim``).
        * ``timings`` has the keys ``'time_binary'``, ``'time_qp_analytical'``
          and ``'time_qp_cvxpy_best'`` (wall-clock seconds).

    Notes:
        * Each of the three sections recomputes ``A @ B.T`` and the Gram
          matrices, so every timing includes that setup cost.
        * Any exception inside a section is swallowed and the corresponding
          ratio is left as NaN.
    """
    A = np.asarray(A_in, dtype=np.float64)
    B = np.asarray(B_in, dtype=np.float64)

    ratios_sq = {
        'Your Bound (Binary)': np.nan,
        'Your Bound (QP Analytical)': np.nan,
        'Your Bound (QP CVXPY Best)': np.nan,
    }
    timings = {
        'time_binary': 0.0,
        'time_qp_analytical': 0.0,
        'time_qp_cvxpy_best': 0.0,
    }

    # Degenerate denominator: report 0 for k == n_dim, 1 for k == 0 and NaN
    # otherwise; timings stay at 0.0.
    if frob_ABt_exact_sq_for_ratio < 1e-20:
        for key in ratios_sq: ratios_sq[key] = 0.0 if k==n_dim else (1.0 if k==0 else np.nan)
        return ratios_sq, timings

    # --- Section 1: 'Your Bound (Binary)' ---
    start_time_binary = time.time()
    try:
        current_frob_ABt_sq_val = frob_norm_sq(A @ B.T)
        current_AtA = A.T @ A
        current_BtB = B.T @ B
        # Elementwise (Hadamard) product of the two Gram matrices.
        current_G = current_AtA * current_BtB
        current_TrG = np.trace(current_G)

        if n_dim > 1:
            # Inside this branch n_dim - 1 > 0 always holds, so the inline
            # fallback after `else` is never taken.
            alpha_k = k / (n_dim - 1) if (n_dim -1) > 0 else (1.0 if k > 0 else 0.0)
            binary_bound_sq_value = max(0, (1.0 - k / n_dim) *
                                        ((1.0 - alpha_k) * current_frob_ABt_sq_val + alpha_k * current_TrG))
            ratios_sq['Your Bound (Binary)'] = binary_bound_sq_value / frob_ABt_exact_sq_for_ratio
        # Fallbacks for n_dim <= 1. The two `n_dim == 1` branches are already
        # covered by the `k == n_dim` and `k == 0` branches above them.
        elif k == n_dim: ratios_sq['Your Bound (Binary)'] = 0.0
        elif k == 0 and n_dim > 0 : ratios_sq['Your Bound (Binary)'] = 1.0
        elif n_dim == 1 and k == 1: ratios_sq['Your Bound (Binary)'] = 0.0
        elif n_dim == 1 and k == 0: ratios_sq['Your Bound (Binary)'] = 1.0
        else: ratios_sq['Your Bound (Binary)'] = np.nan
    except Exception: ratios_sq['Your Bound (Binary)'] = np.nan
    timings['time_binary'] = time.time() - start_time_binary

    # --- Section 2: 'Your Bound (QP Analytical)' ---
    start_time_qp_analytical = time.time()
    try:
        current_frob_ABt_sq_val = frob_norm_sq(A @ B.T)
        current_AtA = A.T @ A
        current_BtB = B.T @ B
        current_G = current_AtA * current_BtB
        current_TrG = np.trace(current_G)

        if n_dim > 1 and current_frob_ABt_sq_val > 1e-12:
            beta_k = (k - 1) / (n_dim - 1) if (n_dim-1) > 0 else (0.0 if k <=1 else 1.0)
            denominator = (beta_k + (1 - beta_k) * (current_TrG / current_frob_ABt_sq_val))
            # Guard against dividing by a numerically zero denominator.
            if abs(denominator) > 1e-12:
                gamma = 1.0 / denominator
                # Already a ratio: not divided by frob_ABt_exact_sq_for_ratio.
                ratios_sq['Your Bound (QP Analytical)'] = max(0, 1.0 - k * gamma / n_dim)
            else: ratios_sq['Your Bound (QP Analytical)'] = np.nan
        elif k == n_dim: ratios_sq['Your Bound (QP Analytical)'] = 0.0
        elif k == 0 and n_dim > 0: ratios_sq['Your Bound (QP Analytical)'] = 1.0
        elif n_dim == 1 and k == 1: ratios_sq['Your Bound (QP Analytical)'] = 0.0
        elif n_dim == 1 and k == 0: ratios_sq['Your Bound (QP Analytical)'] = 1.0
        else: ratios_sq['Your Bound (QP Analytical)'] = np.nan
    except Exception: ratios_sq['Your Bound (QP Analytical)'] = np.nan
    timings['time_qp_analytical'] = time.time() - start_time_qp_analytical

    # --- Section 3: 'Your Bound (QP CVXPY Best)' ---
    start_time_qp_cvxpy = time.time()
    try:
        current_frob_ABt_sq_val = frob_norm_sq(A @ B.T)
        current_AtA = A.T @ A
        current_BtB = B.T @ B
        current_G_matrix = current_AtA * current_BtB
        # Row sums of G (G applied to the all-ones vector).
        current_q_vec_val = current_G_matrix @ np.ones(n_dim) if n_dim > 0 else np.array([])

        if k > 0 and n_dim > 1:
            beta_k = (k - 1) / (n_dim - 1) if (n_dim-1) > 0 else (0.0 if k <=1 else 1.0)
            # Blend of G and its diagonal part, weighted by beta_k.
            G_hat_k = beta_k * current_G_matrix + (1 - beta_k) * np.diag(np.diag(current_G_matrix))
            # Minimise 0.5 * y^T G_hat_k y - q^T y subject to y >= 0.
            y = cp.Variable(n_dim)
            constraints = [y >= 0]
            objective = cp.Minimize(0.5 * cp.quad_form(y, G_hat_k) - current_q_vec_val.T @ y)
            prob = cp.Problem(objective, constraints)
            prob.solve(solver=cp.SCS, verbose=False, eps=1e-7, max_iters=7500)
            # Inaccurate-but-optimal solutions are accepted as well.
            if prob.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                vk_bound_sq_value = max(0, current_frob_ABt_sq_val + (k / n_dim) * 2.0 * prob.value)
                ratios_sq['Your Bound (QP CVXPY Best)'] = vk_bound_sq_value / frob_ABt_exact_sq_for_ratio
            else: ratios_sq['Your Bound (QP CVXPY Best)'] = np.nan
        elif k == 0 and n_dim > 0: ratios_sq['Your Bound (QP CVXPY Best)'] = 1.0
        elif k == n_dim: ratios_sq['Your Bound (QP CVXPY Best)'] = 0.0
        elif n_dim == 1 and k == 1: ratios_sq['Your Bound (QP CVXPY Best)'] = 0.0
        elif n_dim == 1 and k == 0: ratios_sq['Your Bound (QP CVXPY Best)'] = 1.0
        else: ratios_sq['Your Bound (QP CVXPY Best)'] = np.nan
    except Exception: ratios_sq['Your Bound (QP CVXPY Best)'] = np.nan
    timings['time_qp_cvxpy_best'] = time.time() - start_time_qp_cvxpy

    return ratios_sq, timings

# --- STANDARD BOUNDS FUNCTION ---
def compute_standard_bounds(A_in: np.ndarray, B_in: np.ndarray, k: int, n_dim: int,
                            frob_ABt_exact_sq_for_ratio: float) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Compute two reference bounds, relative to the exact squared norm.

    Args:
        A_in: Left matrix; converted to float64.
        B_in: Right matrix; converted to float64, same column count as ``A_in``.
        k: Sample / sketch size.
        n_dim: Number of columns, supplied by the caller.
        frob_ABt_exact_sq_for_ratio: Squared Frobenius norm of the exact
            product; denominator of both ratios.

    Returns:
        ``(ratios, timings)``:

        * ``ratios['Bound (Leverage Score Exp.)']`` is
          ``((sum_i ||A_i|| * ||B_i||)^2 - ||A B^T||_F^2) / k`` over the
          denominator, floored at 0.
        * ``ratios['Bound (Sketching Simple)']`` is
          ``||A||_F^2 * ||B||_F^2 / k`` over the denominator.
        * ``timings`` holds the wall-clock seconds of each computation under
          ``'time_lev_score_exp'`` and ``'time_sketch_simple'``.

        Special cases: for ``k <= 0`` or ``n_dim == 0`` both ratios stay NaN,
        except ``k == 0`` with a non-degenerate denominator and ``n_dim > 0``,
        which yields ``1.0`` and ``inf``. For a degenerate denominator
        (< 1e-20) with ``k > 0`` both ratios are ``0.0``. A failure inside one
        bound emits a ``RuntimeWarning`` and leaves that ratio as NaN.
    """
    A = np.asarray(A_in, dtype=np.float64)
    B = np.asarray(B_in, dtype=np.float64)

    ratios = {'Bound (Leverage Score Exp.)': np.nan, 'Bound (Sketching Simple)': np.nan}
    timings = {'time_lev_score_exp': 0.0, 'time_sketch_simple': 0.0}

    if k <= 0 or n_dim == 0:
        if k == 0 and frob_ABt_exact_sq_for_ratio > 1e-20 and n_dim > 0:
            ratios['Bound (Leverage Score Exp.)'] = 1.0
            ratios['Bound (Sketching Simple)'] = np.inf
        return ratios, timings

    if frob_ABt_exact_sq_for_ratio < 1e-20:
        for key in ratios:
            ratios[key] = 0.0
        return ratios, timings

    # From here on k > 0 is guaranteed, so dividing by k is safe.
    start_time_lev = time.time()
    try:
        current_frob_ABt_sq_val = frob_norm_sq(A @ B.T)
        col_norms_A = np.linalg.norm(A, axis=0)
        col_norms_B = np.linalg.norm(B, axis=0)
        sum_prod_norms = np.sum(col_norms_A * col_norms_B)
        expected_error_sq_value = (sum_prod_norms**2 - current_frob_ABt_sq_val) / k
        ratios['Bound (Leverage Score Exp.)'] = max(0, expected_error_sq_value) / frob_ABt_exact_sq_for_ratio
    except Exception as e:
        warnings.warn(f"Failed Leverage Score bound k={k}, n={n_dim}: {e}", RuntimeWarning)
        ratios['Bound (Leverage Score Exp.)'] = np.nan
    timings['time_lev_score_exp'] = time.time() - start_time_lev

    start_time_sketch = time.time()
    try:
        sketching_bound_sq_value = (frob_norm_sq(A) * frob_norm_sq(B)) / k
        ratios['Bound (Sketching Simple)'] = max(0, sketching_bound_sq_value) / frob_ABt_exact_sq_for_ratio
    except Exception as e:
        warnings.warn(f"Failed Simple Sketching bound k={k}, n={n_dim}: {e}", RuntimeWarning)
        ratios['Bound (Sketching Simple)'] = np.nan
    timings['time_sketch_simple'] = time.time() - start_time_sketch

    return ratios, timings

# --- Algorithm Implementations ---
def run_leverage_score_sampling(A: np.ndarray, B: np.ndarray, k: int, optimal: bool = True, replacement: bool = False) -> np.ndarray:
    """Approximate ``A @ B.T`` from ``k`` randomly sampled, rescaled columns.

    Args:
        A: Array of shape (m, n).
        B: Array of shape (p, n).
        k: Number of columns to draw; capped at ``n`` when sampling without
            replacement.
        optimal: If true, column ``i`` is drawn with probability proportional
            to ``||A[:, i]|| * ||B[:, i]||`` (uniform if that total is below
            1e-12); otherwise uniformly.
        replacement: Whether columns are drawn with replacement.

    Returns:
        The (m, p) approximation. Each drawn column of ``A`` and of ``B`` is
        scaled by ``1 / sqrt(k_used * prob)``.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
    """
    m, n_dim = A.shape
    p_rows_B, n_dim_B = B.shape
    if n_dim != n_dim_B:
        raise ValueError(f"A_cols={n_dim}, B_cols={n_dim_B} mismatch")
    if k <= 0:
        raise ValueError(f"k must be >0 for Leverage Score Sampling, got {k}")
    if n_dim == 0:
        return np.zeros((m, p_rows_B), dtype=A.dtype)

    # Without replacement at most n_dim distinct columns can be drawn.
    n_draws = n_dim if (not replacement and k > n_dim) else k

    if not optimal:
        probs = np.ones(n_dim) / n_dim
    else:
        weights = np.linalg.norm(A, axis=0) * np.linalg.norm(B, axis=0)
        weight_total = np.sum(weights)
        if weight_total >= 1e-12:
            probs = weights / weight_total
        else:
            probs = np.ones(n_dim) / n_dim

    # Keep every probability strictly positive, then renormalise.
    floored = np.maximum(probs, 1e-12)
    probs = floored / floored.sum()

    picked = np.random.choice(n_dim, size=n_draws, replace=replacement, p=probs)
    rescale = 1.0 / np.sqrt(n_draws * probs[picked])

    left = A[:, picked] * rescale
    right = B[:, picked] * rescale
    return left @ right.T

def run_countsketch(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate ``A @ B.T`` with a CountSketch of the shared column axis.

    Every column index is hashed to one of ``k`` buckets and given a random
    sign; the signed columns of ``A`` and of ``B`` are summed per bucket and
    the two sketched matrices are multiplied.

    Args:
        A: Array of shape (m, n).
        B: Array of shape (p, n).
        k: Number of buckets.

    Returns:
        The (m, p) approximation.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
    """
    m, n_dim = A.shape
    p_rows_B, n_dim_B = B.shape
    if n_dim != n_dim_B:
        raise ValueError(f"A_cols={n_dim}, B_cols={n_dim_B} mismatch")
    if k <= 0:
        raise ValueError("k must be >0 for CountSketch")
    if n_dim == 0:
        return np.zeros((m, p_rows_B), dtype=A.dtype)

    buckets = np.random.randint(0, k, size=n_dim)
    signs = np.random.choice([-1.0, 1.0], size=n_dim)

    sketch_A = np.zeros((m, k), dtype=A.dtype)
    sketch_B = np.zeros((p_rows_B, k), dtype=B.dtype)

    # Accumulate each signed column into its bucket, in column order.
    for col, (bucket, sign) in enumerate(zip(buckets, signs)):
        sketch_A[:, bucket] += sign * A[:, col]
        sketch_B[:, bucket] += sign * B[:, col]
    return sketch_A @ sketch_B.T

def run_gaussian_projection(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate ``A @ B.T`` by projecting the column axis with a Gaussian matrix.

    A ``k x n`` matrix with i.i.d. standard normal entries, scaled by
    ``1 / sqrt(k)``, is applied to the columns of both inputs before they are
    multiplied.

    Args:
        A: Array of shape (m, n).
        B: Array of shape (p, n).
        k: Projection dimension.

    Returns:
        The (m, p) approximation.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
    """
    m, n_dim = A.shape
    p_rows_B, n_dim_B = B.shape
    if n_dim != n_dim_B:
        raise ValueError(f"A_cols={n_dim}, B_cols={n_dim_B} mismatch")
    if k <= 0:
        raise ValueError("k must be >0 for Gaussian Projection")
    if n_dim == 0:
        return np.zeros((m, p_rows_B), dtype=A.dtype)

    projection = np.random.randn(k, n_dim) / np.sqrt(k)
    projection_t = projection.T

    sketched_A = A @ projection_t
    sketched_B = B @ projection_t
    return sketched_A @ sketched_B.T

def run_greedy_selection_omp(A: np.ndarray, B: np.ndarray, k: int, ABt_exact: Optional[np.ndarray] = None) -> np.ndarray:
    """Greedily pick up to ``k`` column indices to approximate ``A @ B.T``.

    In each round the not-yet-selected column ``j`` maximising
    ``|A[:, j]^T @ residual @ B[:, j]|`` is added, where ``residual`` is the
    exact product minus the product restricted to the columns selected so
    far. Selection stops early once the best score drops below 1e-9.

    Args:
        A: Array of shape (m, n).
        B: Array of shape (p, n).
        k: Maximum number of columns to select, ``1 <= k <= n``.
        ABt_exact: Optional precomputed ``A @ B.T``; computed here if omitted.

    Returns:
        ``A[:, S] @ B[:, S].T`` for the selected index list ``S`` (all zeros
        if nothing was selected).

    Raises:
        ValueError: If the column counts differ or ``k`` is outside
            ``[1, n]`` (which includes every call with ``n == 0``).
    """
    m, n_dim = A.shape
    p_rows_B, n_dim_B = B.shape
    if n_dim != n_dim_B:
        raise ValueError(f"A_cols={n_dim}, B_cols={n_dim_B} mismatch")
    if not (1 <= k <= n_dim):
        raise ValueError(f"k={k} must be 1 <= k <= n={n_dim} for Greedy OMP")

    if ABt_exact is None:
        ABt_exact = A @ B.T

    # Scores and residuals are computed in float64 regardless of input dtype.
    A_f64, B_f64 = A.astype(np.float64), B.astype(np.float64)
    ABt_exact_f64 = np.asarray(ABt_exact, dtype=np.float64)

    selected_indices = []
    already_selected = np.zeros(n_dim, dtype=bool)
    residual = ABt_exact_f64.copy()

    for _ in range(k):
        # Row j of (A^T @ residual) dotted with column j of B gives
        # A[:, j]^T @ residual @ B[:, j] for every j at once.
        correlations = np.abs(np.sum((A_f64.T @ residual) * B_f64.T, axis=1))
        correlations[already_selected] = -np.inf

        best_idx = int(np.argmax(correlations))
        if correlations[best_idx] < 1e-9:
            break

        selected_indices.append(best_idx)
        already_selected[best_idx] = True
        residual = ABt_exact_f64 - A_f64[:, selected_indices] @ B_f64[:, selected_indices].T

    if not selected_indices:
        return np.zeros((m, p_rows_B), dtype=A.dtype)
    return A[:, selected_indices] @ B[:, selected_indices].T


def fast_walsh_hadamard_transform_manual(X: np.ndarray, axis: int = -1) -> np.ndarray:
    """Apply an unnormalised recursive Walsh-Hadamard butterfly along ``axis``.

    The even- and odd-indexed entries along ``axis`` are transformed
    recursively; their sum fills the first half of the output and their
    difference the second half. No ``1/sqrt(n)`` scaling is applied.

    Args:
        X: Input array; converted to float.
        axis: Axis to transform; its length must be a power of two.

    Returns:
        A float array with the same shape as ``X``. For a length-1 axis the
        converted input is returned as-is.

    Raises:
        ValueError: If the length along ``axis`` is not a power of two.
    """
    Y = np.asarray(X, dtype=float)
    n_ax = Y.shape[axis]
    original_axis = axis
    if axis < 0:
        axis = Y.ndim + axis

    is_power_of_two = n_ax > 0 and (n_ax & (n_ax - 1)) == 0
    if not is_power_of_two:
        raise ValueError(f"Input size {n_ax} along axis {original_axis} must be a power of 2.")
    if n_ax == 1:
        return Y

    def along_axis(part: slice) -> tuple:
        # Full index tuple that applies `part` on the transform axis only.
        index = [slice(None)] * Y.ndim
        index[axis] = part
        return tuple(index)

    even_part = fast_walsh_hadamard_transform_manual(Y[along_axis(slice(None, None, 2))], axis=axis)
    odd_part = fast_walsh_hadamard_transform_manual(Y[along_axis(slice(1, None, 2))], axis=axis)

    half = n_ax // 2
    transformed = np.empty_like(Y)
    transformed[along_axis(slice(0, half))] = even_part + odd_part
    transformed[along_axis(slice(half, n_ax))] = even_part - odd_part
    return transformed

def pad_matrix(A_matrix: np.ndarray, axis: int = 1) -> Tuple[np.ndarray, int]:
    """Zero-pad ``A_matrix`` along ``axis`` up to the next power of two.

    Args:
        A_matrix: Array to pad.
        axis: Axis whose length is rounded up.

    Returns:
        ``(padded, original_length)``. If the length is zero or already a
        power of two, the input array itself is returned unchanged.
    """
    n_orig = A_matrix.shape[axis]
    if n_orig == 0:
        return A_matrix, 0

    # Smallest power of two >= n_orig (equals n_orig for exact powers of two).
    target_len = 1 << (n_orig - 1).bit_length()
    extra = target_len - n_orig
    if extra <= 0:
        return A_matrix, n_orig

    widths = [(0, 0)] * A_matrix.ndim
    widths[axis] = (0, extra)
    padded = np.pad(A_matrix, pad_width=widths, mode='constant', constant_values=0)
    return padded, n_orig

def run_srht_new(A: np.ndarray, B: np.ndarray, k: int, optimal_sampling: bool = False) -> np.ndarray:
    """Approximate ``A @ B.T`` with a subsampled randomised Hadamard transform.

    Both inputs are zero-padded along the column axis to a power of two,
    multiplied by one shared random sign vector, mixed by a normalised
    Walsh-Hadamard transform, and then ``k`` transformed columns are sampled
    uniformly without replacement and rescaled by ``sqrt(N / k)``.

    The SAME sign vector and the same sampled columns must be used for ``A``
    and ``B``: the estimate is ``(A S)(B S)^T`` for a single sketch ``S``, and
    only then is its expectation ``A @ B.T``. With independent signs the two
    sides use different sketches and the estimate is centred on zero.

    Args:
        A: Array of shape (m, n).
        B: Array of shape (p, n).
        k: Number of transformed columns to keep; reduced (with a
            ``RuntimeWarning``) to the padded width if larger.
        optimal_sampling: Accepted for interface compatibility; not used.

    Returns:
        The (m, p) approximation. When ``k`` is at least the padded width the
        result equals ``A @ B.T`` up to rounding.

    Raises:
        ValueError: If the column counts differ or ``k <= 0``.
        RuntimeError: If padding yields inconsistent widths or the transform
            fails.
    """
    m, n_dim = A.shape
    p_rows_B, n_dim_B = B.shape
    if n_dim != n_dim_B:
        raise ValueError(f"A_cols={n_dim}, B_cols={n_dim_B} mismatch")
    if k <= 0:
        raise ValueError(f"k must be >0 for SRHT, got {k}")
    if n_dim == 0:
        return np.zeros((m, p_rows_B), dtype=A.dtype)

    A_padded, _ = pad_matrix(A, axis=1)
    B_padded, _ = pad_matrix(B, axis=1)
    N_padded = A_padded.shape[1]

    if N_padded != B_padded.shape[1]:
        raise RuntimeError(f"Padded dimensions mismatch: A_pad={A_padded.shape[1]}, B_pad={B_padded.shape[1]}")

    # n_dim > 0 and k > 0 here, so both N_padded and k_actual are >= 1.
    k_actual = min(k, N_padded)
    if k_actual < k:
        warnings.warn(f"SRHT sampling k reduced from {k} to {k_actual} due to padded dim N={N_padded}", RuntimeWarning)

    # One sign vector shared by A and B (see docstring).
    sign_flips = np.random.choice([-1.0, 1.0], size=N_padded)

    try:
        normaliser = np.sqrt(N_padded)
        HA = fast_walsh_hadamard_transform_manual(A_padded * sign_flips, axis=1) / normaliser
        HB = fast_walsh_hadamard_transform_manual(B_padded * sign_flips, axis=1) / normaliser
    except ValueError as ve:
        raise RuntimeError(f"Manual FWHT failed: {ve}. Shapes: A_pad={A_padded.shape}, B_pad={B_padded.shape}") from ve
    except Exception as e:
        raise RuntimeError(f"Unexpected error in Manual FWHT: {e}\n{traceback.format_exc()}") from e

    sampled_indices_padded = np.random.choice(N_padded, size=k_actual, replace=False)
    scaling_factor = np.sqrt(N_padded / k_actual)

    A_reduced = HA[:, sampled_indices_padded] * scaling_factor
    B_reduced = HB[:, sampled_indices_padded] * scaling_factor

    return A_reduced @ B_reduced.T

# ==============================================================================
# Experiment 3 Runner Function
# ==============================================================================
def run_experiment_3_scalability(n_values: List[int], m_dim_exp: int, p_dim_exp: int, k_ratio_exp: float,
                                 n_trials_exp: int, base_seed_exp: int,
                                 matrix_dist_type_exp: str = 'gaussian') -> Dict[int, Dict]:
    """Run the scalability experiment over a list of inner dimensions ``n``.

    For every ``n`` in ``n_values`` the function generates ``A`` and ``B``,
    computes the exact product, the theoretical and standard bounds, and the
    relative squared error of each approximation algorithm with budget
    ``k = int(n * k_ratio_exp)``.

    Args:
        n_values: Inner dimensions (column counts) to process, in order.
        m_dim_exp: Number of rows of ``A``.
        p_dim_exp: Number of rows of ``B``.
        k_ratio_exp: Budget as a fraction of ``n``; values of ``n`` that give
            ``k <= 0`` are skipped.
        n_trials_exp: Number of trials averaged for each stochastic algorithm.
        base_seed_exp: Seed used for the first ``n``; advanced by
            ``n_trials_exp + 10`` for each processed ``n``.
        matrix_dist_type_exp: Distribution name passed to ``generate_matrices``.

    Returns:
        A dict keyed by ``n``. Each value holds ``'k'``, ``'rho_G'``,
        ``'matrix_dist_type'``, ``'m_dim'``, ``'p_dim'``, ``'results'``
        (ratios relative to the exact squared Frobenius norm, clipped to
        [0, 10], plus the raw ``'Frob ABT Sq'``) and ``'times'`` (seconds).

    Notes:
        * NumPy's global RNG is re-seeded repeatedly, so the caller's RNG
          state is not preserved.
        * A ``MemoryError`` while forming the exact product stops the whole
          loop; any other error there skips only the current ``n``.
        * Progress and errors are reported with ``print``.
    """
    experiment_data = {}
    current_seed_for_n_iter = base_seed_exp
    print(f"\n=== Starting Experiment 3: Scalability Analysis ===")
    print(f"Params: n_values={n_values}, m={m_dim_exp}, p={p_dim_exp}, k_ratio={k_ratio_exp}, trials={n_trials_exp}, dist={matrix_dist_type_exp.capitalize()}")

    for n_val_iter in tqdm(n_values, desc="Overall N Progress"):
        print(f"\n--- Processing n = {n_val_iter} (Dist: {matrix_dist_type_exp.capitalize()}) ---")
        k_val = int(n_val_iter * k_ratio_exp)
        if k_val <= 0:
            print(f"  Warning: Calculated k={k_val} for n={n_val_iter}. Skipping this n.")
            continue

        current_n_dim_iter = n_val_iter
        # Seed the global RNG here and pass seed=None so that
        # generate_matrices does not re-seed it.
        np.random.seed(current_seed_for_n_iter)
        A_orig, B_orig = generate_matrices(m_dim_exp, p_dim_exp, current_n_dim_iter, seed=None, distribution=matrix_dist_type_exp)

        algo_base_seed_this_n = current_seed_for_n_iter + 1
        current_seed_for_n_iter += (n_trials_exp + 10) # Increment seed for next n

        try:
            rho_G_val = calculate_rho_g(A_orig, B_orig)
            print(f"  Generated A({A_orig.shape}), B({B_orig.shape}). Rho_G = {rho_G_val:.4f}. k = {k_val}.")
        except Exception as e:
            print(f"  Error calculating Rho_G for n={current_n_dim_iter}: {e}. Setting Rho_G=NaN.")
            rho_G_val = np.nan

        print(f"  Calculating exact AB^T for n={current_n_dim_iter}...")
        try:
            A_f64, B_f64 = A_orig.astype(np.float64), B_orig.astype(np.float64)
            start_time_abt = time.time()
            ABt_exact = A_f64 @ B_f64.T
            time_abt = time.time() - start_time_abt
            frob_ABt_sq_exact = frob_norm_sq(ABt_exact)

            if frob_ABt_sq_exact < 1e-20:
                 print(f"  Warning: ||AB^T||_F^2 is near zero for n={current_n_dim_iter}. Results might be trivial.")
            print(f"  ||AB^T||_F^2 = {frob_ABt_sq_exact:.4e} (Computed in {time_abt:.3f}s)")
        except MemoryError:
            print(f"  MemoryError during exact AB^T calculation for n={current_n_dim_iter}. Stopping experiment for larger n."); break
        except Exception as e:
            print(f"  Error calculating exact AB^T for n={current_n_dim_iter}: {e}. Skipping this n."); continue

        all_results_n, all_times_n = {'Frob ABT Sq': frob_ABt_sq_exact}, {'AB^T': time_abt}

        print("  Computing theoretical bounds...")
        # Bound names are translated to the keys used by the plotting styles.
        # If a bounds call raises, its keys are simply absent from the results.
        try:
            user_bounds_sq_ratios, user_bounds_timings = compute_theoretical_bounds(
                A_f64, B_f64, k_val, current_n_dim_iter, frob_ABt_sq_exact
            )
            all_results_n['Bound (QP Aux)'] = user_bounds_sq_ratios.get('Your Bound (QP CVXPY Best)', np.nan)
            all_results_n['Bound (Scaled Id)'] = user_bounds_sq_ratios.get('Your Bound (QP Analytical)', np.nan)
            all_results_n['Your Bound (Binary)'] = user_bounds_sq_ratios.get('Your Bound (Binary)', np.nan)

            all_times_n['Time Bound (QP Aux)'] = user_bounds_timings.get('time_qp_cvxpy_best', np.nan)
            all_times_n['Time Bound (Scaled Id)'] = user_bounds_timings.get('time_qp_analytical', np.nan)
            all_times_n['Time Your Bound (Binary)'] = user_bounds_timings.get('time_binary', np.nan)

        except Exception as e: print(f"    Error computing user bounds for k={k_val}, n={current_n_dim_iter}: {e}")

        try:
            std_bounds_ratios, std_bounds_timings = compute_standard_bounds(
                A_f64, B_f64, k_val, current_n_dim_iter, frob_ABt_sq_exact
            )
            all_results_n['Bound (Sampling)'] = std_bounds_ratios.get('Bound (Leverage Score Exp.)', np.nan)
            all_results_n['Bound (Sketching)'] = std_bounds_ratios.get('Bound (Sketching Simple)', np.nan)

            all_times_n['Time Bound (Sampling)'] = std_bounds_timings.get('time_lev_score_exp', np.nan)
            all_times_n['Time Bound (Sketching)'] = std_bounds_timings.get('time_sketch_simple', np.nan)

        except Exception as e: print(f"    Error computing standard bounds for k={k_val}, n={current_n_dim_iter}: {e}")

        print("  Running approximation algorithms...")
        # Each entry maps a result key to (zero-argument runner, timing key).
        # The lambdas close over this iteration's matrices and are only called
        # within this iteration.
        algo_runners_map = {
            'Optimal Sampling': (lambda: run_leverage_score_sampling(A_f64, B_f64, k_val), 'Optimal Sampling'),
            'CountSketch':      (lambda: run_countsketch(A_f64, B_f64, k_val), 'CountSketch'),
            'SRHT':             (lambda: run_srht_new(A_f64, B_f64, k_val), 'SRHT'),
            'Gaussian Proj.':   (lambda: run_gaussian_projection(A_f64, B_f64, k_val), 'Gaussian Proj.'),
            'Greedy OMP':       (lambda: run_greedy_selection_omp(A_f64, B_f64, k_val, ABt_exact), 'Greedy OMP')
        }

        for result_key_name, (runner_func, time_key_name) in algo_runners_map.items():
            # Offset the seed by the sum of the name's character codes so the
            # algorithms use different per-trial seeds.
            current_algo_trials_base_seed = algo_base_seed_this_n + sum(ord(c) for c in time_key_name)
            start_time_algo_run_total = time.time()
            try:
                if time_key_name not in ['Greedy OMP']: # Stochastic algorithms
                    errors_sq_trials_list = []
                    single_trial_times_list = []
                    for trial_idx in range(n_trials_exp):
                        np.random.seed(current_algo_trials_base_seed + trial_idx)
                        trial_start_time_sec = time.time()
                        approx_ABt_trial = runner_func()
                        single_trial_times_list.append(time.time() - trial_start_time_sec)
                        errors_sq_trials_list.append(frob_norm_sq(ABt_exact - approx_ABt_trial))

                    # Report the mean squared error over trials, relative to
                    # the exact squared norm, and the mean time per trial.
                    mean_error_sq = np.mean(errors_sq_trials_list) if errors_sq_trials_list else np.nan
                    all_results_n[result_key_name] = mean_error_sq / frob_ABt_sq_exact if frob_ABt_sq_exact > 1e-20 else 0.0
                    all_times_n[time_key_name] = np.mean(single_trial_times_list) if single_trial_times_list else np.nan
                else: # Deterministic algorithms
                    np.random.seed(current_algo_trials_base_seed) # Still set seed for consistency if OMP had any randomness (it doesn't here)
                    approx_ABt_single = runner_func()
                    all_results_n[result_key_name] = frob_norm_sq(ABt_exact - approx_ABt_single) / frob_ABt_sq_exact if frob_ABt_sq_exact > 1e-20 else 0.0
                    all_times_n[time_key_name] = time.time() - start_time_algo_run_total
            except ValueError as ve: print(f"    VALUE ERROR in {time_key_name} (k={k_val}, n={current_n_dim_iter}): {ve}"); all_results_n[result_key_name] = np.nan; all_times_n[time_key_name] = np.nan
            except Exception as e: print(f"    UNEXPECTED ERROR in {time_key_name} (k={k_val}, n={current_n_dim_iter}): {e}"); traceback.print_exc(); all_results_n[result_key_name] = np.nan; all_times_n[time_key_name] = np.nan

        # Clip every non-NaN float ratio to [0, 10] (infinite bounds become
        # 10); the raw 'Frob ABT Sq' entry is left untouched.
        for res_key in all_results_n:
             if res_key not in ['Frob ABT Sq'] and isinstance(all_results_n[res_key], (float, np.float64)):
                 if not np.isnan(all_results_n[res_key]):
                     all_results_n[res_key] = np.clip(all_results_n[res_key], 0, 10.0) # Clip relative squared errors/bounds

        experiment_data[current_n_dim_iter] = {
            'k': k_val, 'rho_G': rho_G_val, 'matrix_dist_type': matrix_dist_type_exp,
            'results': all_results_n, 'times': all_times_n,
            'm_dim': m_dim_exp, 'p_dim': p_dim_exp
        }
    print(f"\n=== Experiment 3 (Dist: {matrix_dist_type_exp.capitalize()}) Data Collection Finished. ===")
    return experiment_data

# ==============================================================================
# Plotting Function for Experiment 3
# ==============================================================================
def _exp3_plot_series(ax, x_coords: np.ndarray, raw_values, floor: float,
                      inclusive: bool, line_style: Dict) -> None:
    """Plot the entries of raw_values that are finite and above `floor` on `ax`."""
    # dtype=float turns None placeholders into NaN instead of producing an
    # object array that np.isfinite cannot handle.
    values = np.array(raw_values, dtype=float)
    above_floor = (values >= floor) if inclusive else (values > floor)
    keep = np.isfinite(values) & above_floor
    if np.any(keep):
        ax.plot(x_coords[keep], values[keep], **line_style)


def _exp3_style_axis(ax) -> None:
    """Apply the shared dotted grid and hide the top/right spines of `ax`."""
    ax.grid(True, which='both', linestyle=':',
            linewidth=plt.rcParams['grid.linewidth'], alpha=plt.rcParams['grid.alpha'])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)


def _exp3_save_figure(fig, full_filename: str, label: str) -> None:
    """Save `fig` to `full_filename`, printing the outcome instead of raising."""
    try:
        fig.savefig(full_filename, bbox_inches='tight')
        print(f"Saved {label} plot: {full_filename}")
    except Exception as e:
        # Best effort: a failed save must not abort the remaining plots.
        print(f"Error saving {label} plot {full_filename}: {e}")


def plot_experiment_3_scalability(exp_results_data: Dict[int, Dict],
                                  styles_to_use: Dict,
                                  plot_filename_prefix: str = "Exp3_Scalability") -> None:
    """Plot the Experiment 3 scalability results and save them under ``plots/``.

    Two figures are produced:

    * a two-panel figure (values/errors vs. n on log-log axes, and times vs. n
      with a log y-axis and the shared log x-axis), drawn without a legend;
    * a single-panel figure of the recorded ``rho_G`` value vs. n, only if at
      least one finite ``rho_G`` value is present.

    Args:
        exp_results_data: Mapping from dimension n to a per-n record. The
            record is read for the keys ``'results'``, ``'times'``, ``'rho_G'``,
            ``'k'``, ``'matrix_dist_type'``, ``'m_dim'`` and ``'p_dim'``.
            Records with a missing or empty ``'results'`` entry are skipped; a
            missing ``'times'`` entry is treated as "no timings recorded".
        styles_to_use: Mapping from series name to matplotlib line keyword
            arguments. A ``'label'`` entry is dropped because no legend is
            drawn; series without a style are skipped with a ``UserWarning``.
        plot_filename_prefix: Prefix of the generated PNG file names.

    Returns:
        None. The function prints progress messages, shows the figures, and
        writes PNG files; save failures are reported via ``print``.
    """
    if not exp_results_data:
        print("No results to plot for Experiment 3.")
        return

    n_values_sorted_plot = sorted(exp_results_data.keys())
    if not n_values_sorted_plot:
        print("No n values found in results data.")
        return

    # Title and file-name metadata come from the record of the smallest n.
    smallest_n = n_values_sorted_plot[0]
    first_n_res = exp_results_data.get(smallest_n, {})
    matrix_dist_plot = first_n_res.get('matrix_dist_type', 'UnknownDist').capitalize()
    m_dim_plot = first_n_res.get('m_dim', 'M')
    p_dim_plot = first_n_res.get('p_dim', 'P')
    k_val_first_n = first_n_res.get('k', np.nan)
    if smallest_n > 0 and not np.isnan(k_val_first_n):
        k_ratio_plot_str = f"{k_val_first_n / smallest_n:.2f}"
    else:
        k_ratio_plot_str = 'UnknownKRatio'

    defined_algo_result_keys = ['Optimal Sampling', 'CountSketch', 'SRHT', 'Gaussian Proj.', 'Greedy OMP']
    defined_bound_result_keys = ['Bound (QP Aux)', 'Bound (Scaled Id)', 'Your Bound (Binary)', 'Bound (Sampling)', 'Bound (Sketching)']
    all_value_plot_keys = defined_algo_result_keys + defined_bound_result_keys

    # Algorithm timings are looked up under the series name itself, bound
    # timings under "Time <series name>".
    time_key_for = {key: key for key in defined_algo_result_keys}
    time_key_for.update({key: f"Time {key}" for key in defined_bound_result_keys})

    plot_data_struct = {key: {'value_or_error': [], 'time': []} for key in all_value_plot_keys}
    rho_G_values_for_plot, valid_n_coords_plot = [], []

    for n_coord in n_values_sorted_plot:
        record_at_n = exp_results_data[n_coord]
        res_at_n = record_at_n.get('results')
        if not res_at_n:
            continue
        # A record with results but no timing dict still contributes its values.
        time_at_n = record_at_n.get('times') or {}

        valid_n_coords_plot.append(n_coord)
        rho_G_values_for_plot.append(record_at_n.get('rho_G', np.nan))
        for key_style in all_value_plot_keys:
            series = plot_data_struct[key_style]
            series['value_or_error'].append(res_at_n.get(key_style, np.nan))
            series['time'].append(time_at_n.get(time_key_for[key_style], np.nan))

    if not valid_n_coords_plot:
        print("No valid data points to plot after processing.")
        return
    valid_n_coords_plot_np = np.array(valid_n_coords_plot)
    kratio_tag = k_ratio_plot_str.replace('.', 'pt')

    # ---- Main figure: values/errors (left) and times (right) ----
    fig_main, (ax_values_plot, ax_times_plot) = plt.subplots(
        1, 2, figsize=plt.rcParams['figure.figsize'], sharex=True)
    try:
        for key_to_plot in all_value_plot_keys:
            style_params = styles_to_use.get(key_to_plot)
            if not style_params:
                warnings.warn(f"Style not found for '{key_to_plot}'. Skipping in plot.", UserWarning)
                continue
            # No legend on this figure, so the label entry is dropped.
            line_style = {k: v for k, v in style_params.items() if k != 'label'}

            # Floors keep non-positive / vanishing entries off the log axes.
            _exp3_plot_series(ax_values_plot, valid_n_coords_plot_np,
                              plot_data_struct[key_to_plot]['value_or_error'],
                              floor=1e-12, inclusive=True, line_style=line_style)
            _exp3_plot_series(ax_times_plot, valid_n_coords_plot_np,
                              plot_data_struct[key_to_plot]['time'],
                              floor=1e-9, inclusive=False, line_style=line_style)

        ax_values_plot.set_title("Algorithm Errors & Bound Values vs. n")
        ax_values_plot.set_ylabel("Relative Sq. Error / Bound Value (Log)")
        ax_values_plot.set_yscale('log')
        ax_values_plot.set_xscale('log')  # sharex=True applies this to the time axis too
        _exp3_style_axis(ax_values_plot)
        ax_values_plot.set_xlabel("Dimension n (Log Scale)")

        ax_times_plot.set_title("Computation Time vs. n")
        ax_times_plot.set_ylabel("Time (seconds, Log Scale)")
        ax_times_plot.set_yscale('log')
        _exp3_style_axis(ax_times_plot)
        ax_times_plot.set_xlabel("Dimension n (Log Scale)")

        fig_main.suptitle(f"Scalability Analysis (Dist: {matrix_dist_plot}, m={m_dim_plot}, p={p_dim_plot}, k/n ≈ {k_ratio_plot_str})", y=1.0)
        fig_main.tight_layout(rect=[0.02, 0.03, 0.98, 0.95])
        plt.show()  # Show plot in notebook

        main_plot_full_filename = os.path.join("plots", f"{plot_filename_prefix}_{matrix_dist_plot}_m{m_dim_plot}_p{p_dim_plot}_kratio{kratio_tag}_main_nolegend.png")
        _exp3_save_figure(fig_main, main_plot_full_filename, "main")
    finally:
        # Release the figure even if plotting raised, so it does not stay
        # registered with pyplot.
        plt.close(fig_main)

    # ---- Secondary figure: rho_G vs. n ----
    rho_G_values_for_plot_np = np.array(rho_G_values_for_plot, dtype=float)
    valid_rho_mask_plot = np.isfinite(rho_G_values_for_plot_np)
    if not np.any(valid_rho_mask_plot):
        print("Skipping Rho_G plot: no valid Rho_G values were collected or available.")
        return

    fig_rho_plot, ax_rho_plot = plt.subplots(figsize=(10, 6))
    try:
        ax_rho_plot.plot(valid_n_coords_plot_np[valid_rho_mask_plot], rho_G_values_for_plot_np[valid_rho_mask_plot],
                         marker='o', linestyle='-', color='purple', lw=2.8, markersize=10, mfc='mediumorchid', mec='purple', alpha=0.8)
        ax_rho_plot.set_xlabel("Dimension n (Log Scale)")
        ax_rho_plot.set_ylabel("Calculated $\\rho_G(A, B)$")
        ax_rho_plot.set_title(f"Observed $\\rho_G$ vs. n (Dist: {matrix_dist_plot})")
        ax_rho_plot.set_xscale('log')
        _exp3_style_axis(ax_rho_plot)
        fig_rho_plot.tight_layout()
        plt.show()  # Show plot in notebook

        rho_plot_full_filename = os.path.join("plots", f"{plot_filename_prefix}_{matrix_dist_plot}_m{m_dim_plot}_p{p_dim_plot}_kratio{kratio_tag}_rho_G.png")
        _exp3_save_figure(fig_rho_plot, rho_plot_full_filename, "Rho_G")
    finally:
        plt.close(fig_rho_plot)